from vaemodel import VAE
import pickle
import numpy as np
from trainprogrampredictor import Predictor, create_from_noise
import torch

# Script: sample graph structures from noise with the structure predictor,
# decode each one with the conditional VAE, and pickle a slice of every
# decoded tensor to "pickles/recons.pcl".

# map_location="cpu" lets checkpoints saved from a GPU run load on a CPU-only
# machine; the models are constructed on CPU and `.numpy()` below requires CPU
# tensors anyway.
struc_predictor = Predictor()
struc_predictor.load_state_dict(
    torch.load("graphnn/predictprogram.pth", map_location="cpu")
)

cond_vae = VAE()
cond_vae.load_state_dict(torch.load("graphnn/graphvae.pth", map_location="cpu"))

magenta_size = 256  # number of leading columns kept from each decoded tensor
num_symmetries = 22  # NOTE(review): not used in this script; kept for importers
num_samples = 2000  # how many noise samples to draw and decode
num_tail_rows = 20  # how many trailing rows (first axis) to keep per sample
magents = []
random_magents = []  # NOTE(review): never populated here; kept for importers
for _ in range(num_samples):
    x, edge_index, edge_attr = create_from_noise(struc_predictor)
    # Inference only: don't build an autograd graph for the decoder pass.
    # NOTE(review): edge_index is transposed before decoding — presumably
    # create_from_noise returns it as (num_edges, 2); confirm.
    with torch.no_grad():
        recon = cond_vae.decoder(x, edge_index.T, edge_attr)

    # Keep the last `num_tail_rows` rows and the first `magenta_size` columns,
    # stored as a list of per-row NumPy arrays.
    recon = list(recon[-num_tail_rows:, :magenta_size].detach().numpy())
    magents.append(recon)

# The context manager guarantees the pickle is flushed and the file closed.
with open("pickles/recons.pcl", "wb") as out_file:
    pickle.dump(magents, out_file)
